import pickle
import random as rd
import numpy as np
import scipy.sparse as sp
from scipy.io import loadmat
import copy as cp
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score,confusion_matrix
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE
import torch 
import copy as cp
import os
from sklearn.metrics import confusion_matrix
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import (
    sort_edge_index, degree, add_remaining_self_loops, 
    remove_self_loops, get_laplacian, to_undirected, 
    to_dense_adj, to_networkx, dropout_adj
)
import networkx as nx
import random
import os
import os.path as osp
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch.sparse import mm
import time
from functools import wraps

# Adjacency-list pickle file name for each relation, keyed by
# '<dataset>_<relation>'. The loaders in this module join these names with
# their ``prefix`` argument to obtain the actual path.
filelist = {
    'amz_upu': 'amz_upu_adjlists.pickle',
    'amz_usu': 'amz_usu_adjlists.pickle',
    'amz_uvu': 'amz_uvu_adjlists.pickle',
    'yelp_rsr': 'yelp_rsr_adjlists.pickle',
    'yelp_rtr': 'yelp_rtr_adjlists.pickle',
    'yelp_rur': 'yelp_rur_adjlists.pickle'
}

# File-name prefix of the per-order matrix pickles for each relation; the
# loaders append '<i>.pkl' for i = 1..k.
# NOTE(review): the 'decompision' spelling presumably matches the file names
# on disk, so it must not be "corrected" here without renaming the files.
file_matrix_prefix = {
    'amz_upu': 'amazon_upu_matrix_',
    'amz_usu': 'amazon_usu_matrix_',
    'amz_uvu': 'amazon_uvu_matrix_',
    'yelp_rsr': 'yelpnet_rsr_matrix_decompision_',
    'yelp_rtr': 'yelpnet_rtr_matrix_decompision_',
    'yelp_rur': 'yelpnet_rur_matrix_decompision_'
}

def fetch_labeled_data(batch_nodes, labels, feat_data):
    """Gather the labels that belong to a batch of nodes.

    NOTE: a second ``fetch_labeled_data`` defined further down in this module
    rebinds the name at import time, so that later definition is the one
    callers actually get.

    Args:
        batch_nodes: Node indices that make up the batch.
        labels: Labels of all nodes; indexed with the batch indices.
        feat_data: Node feature tensor. It is returned untouched and is only
            consulted for the device the labels should live on.

    Returns:
        Tuple ``(feat_data, batch_labels)`` where ``batch_labels`` is an
        int64 tensor placed on ``feat_data.device``.
    """
    node_ids = np.array(batch_nodes)
    picked = torch.tensor(labels[node_ids])
    target_device = feat_data.device
    return feat_data, picked.long().to(target_device)
    
def make_loss_function(name, weight=None):
    """Build a loss object from its short configuration name.

    Args:
        name: One of ``'ce'``, ``'wce'``, ``'ce+dice'``, ``'wce+dice'`` or
            ``'w_ce+dice'``.
        weight: Class weights forwarded to the weighted variants; ignored by
            ``'ce'`` and ``'ce+dice'``.

    Returns:
        A freshly constructed ``RobustCrossEntropyLoss`` or ``DC_and_CE_loss``.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    # Constructors are wrapped in lambdas so only the requested loss is built.
    builders = {
        'ce': lambda: RobustCrossEntropyLoss(),
        'wce': lambda: RobustCrossEntropyLoss(weight=weight),
        'ce+dice': lambda: DC_and_CE_loss(),
        'wce+dice': lambda: DC_and_CE_loss(w_ce=weight),
        'w_ce+dice': lambda: DC_and_CE_loss(w_dc=weight, w_ce=weight),
    }
    build = builders.get(name)
    if build is None:
        raise ValueError(name)
    return build()

def EMA(cur_weight, past_weight, momentum=0.9):
    """Blend a new observation into a running value (exponential moving average).

    Args:
        cur_weight: The newly observed value.
        past_weight: The running value accumulated so far.
        momentum: Fraction of ``past_weight`` that is retained.

    Returns:
        ``momentum * past_weight + (1 - momentum) * cur_weight``.
    """
    retained = momentum * past_weight
    incoming = (1 - momentum) * cur_weight
    return retained + incoming

class DistDW:
    """Distribution-aware dynamic class weighting.

    Class weights are derived from how often each class occurs: the weight of
    a class is ``log(max_count / count)`` rescaled so the rarest class gets 1,
    then multiplied by ``num_cls``. ``get_ema_weights`` tracks these weights
    over time with an exponential moving average of the pseudo-label
    distribution.

    Attributes:
        num_cls: Number of classes.
        do_bg: Stored as given; not read by any method of this class.
        momentum: Fraction of the previous weights kept on each EMA update.
        device: Device on which weight tensors are created.
        weights: Current weight tensor, or ``None`` until the first call to
            ``init_weights`` / ``get_ema_weights``.
    """

    def __init__(self, num_cls, do_bg=False, momentum=0.95, device='cuda:0'):
        self.num_cls = num_cls
        self.do_bg = do_bg
        self.momentum = momentum
        self.device = device
        # Created lazily so that constructing the object never touches the
        # device; see _current_weights().
        self.weights = None

    def _current_weights(self):
        """Return the stored weights, creating neutral ones if none exist yet.

        The neutral value ``num_cls`` for every class is exactly what
        ``_cal_weights`` yields for a perfectly balanced distribution, so an
        EMA that starts from it is unbiased. Previously calling
        ``get_ema_weights`` before ``init_weights`` raised AttributeError.
        """
        if self.weights is None:
            self.weights = torch.ones(
                self.num_cls, dtype=torch.float32, device=self.device
            ) * self.num_cls
        return self.weights

    def _cal_weights(self, num_each_class):
        """Turn per-class counts into weights in which the rarest class is 1.

        Args:
            num_each_class: Sequence/array with one count per class.

        Returns:
            Float tensor on ``self.device`` with one weight per class. Falls
            back to all-ones when the counts are uniform or NaNs appear.
        """
        num_each_class = torch.as_tensor(
            num_each_class, dtype=torch.float32, device=self.device
        )
        # Ratio of the largest class to every class; the epsilons avoid a
        # division by zero for classes that do not occur at all.
        P = (num_each_class.max() + 1e-8) / (num_each_class + 1e-5)
        if torch.isnan(P).any():
            print(f"NaN detected in P: {P}")
            print(f"num_each_class: {num_each_class}")
            print(f"max value: {num_each_class.max()}")

        P_log = torch.log(P)
        if torch.isnan(P_log).any():
            print(f"NaN detected in P_log: {P_log}")
            print(f"P values: {P}")

        # Rescale so the largest log-ratio becomes 1; a zero maximum means
        # there is nothing to rescale by, so use uniform weights instead.
        max_val = P_log.max()
        if max_val == 0:
            weight = torch.ones_like(P_log)
        else:
            weight = P_log / max_val

        if torch.isnan(weight).any():
            print(f"NaN detected in weight: {weight}")
            print(f"P_log: {P_log}")
            print(f"P_log.max(): {P_log.max()}")
            # Fall back to a safe value.
            weight = torch.ones_like(weight)

        return weight

    def init_weights(self, labeled_dataset):
        """Initialise the weights from the class distribution of labeled data.

        Args:
            labeled_dataset: Object exposing ``labels`` (labels of all nodes)
                and ``indices`` (positions of the labeled nodes).

        Returns:
            The initial weights as a NumPy array.

        Raises:
            ValueError: If a label is negative or not smaller than ``num_cls``.
        """
        labels = np.asarray(labeled_dataset.labels)[labeled_dataset.indices]
        # bincount needs integer input.
        labels = labels.astype(np.int64)

        class_counts = np.bincount(labels, minlength=self.num_cls)
        if len(class_counts) > self.num_cls:
            raise ValueError(
                f"labels contain class {len(class_counts) - 1} "
                f"but num_cls is {self.num_cls}"
            )

        weights = self._cal_weights(class_counts.astype(np.float64))
        self.weights = weights * self.num_cls
        return self.weights.detach().cpu().numpy()

    def get_ema_weights(self, pseudo_label):
        """Update the weights from a batch of predictions and return them.

        Args:
            pseudo_label: Prediction tensor whose dimension 1 indexes classes;
                the arg-max over that dimension is used as the hard label.

        Returns:
            The updated weight tensor. If the input (or the freshly computed
            weights) contains NaN, the previous weights are returned unchanged.
        """
        if torch.isnan(pseudo_label).any():
            print("NaN detected in pseudo_label input")
            return self._current_weights()  # keep the previous weights

        hard_labels = torch.argmax(pseudo_label.detach(), dim=1).cpu().numpy()

        # One histogram over all hard labels replaces the former per-sample
        # Python loop; the bin edges 0..num_cls are the same as before.
        num_each_class, _ = np.histogram(
            hard_labels.reshape(-1), range(self.num_cls + 1)
        )

        cur_weights = self._cal_weights(num_each_class) * self.num_cls

        if torch.isnan(cur_weights).any():
            print("NaN detected in cur_weights, using previous weights")
            return self._current_weights()

        # Exponential moving average, written out inline:
        # keep `momentum` of the old weights and blend in the rest.
        past_weights = self._current_weights()
        self.weights = (
            self.momentum * past_weights + (1 - self.momentum) * cur_weights
        )

        # Last line of defence against NaNs introduced by the update itself.
        if torch.isnan(self.weights).any():
            print("NaN detected in final weights")
            self.weights = torch.ones_like(self.weights)

        return self.weights

# class DiffDW:
#     """Difficulty-aware Dynamic Weighting"""
#     def __init__(self, num_cls=2, momentum=0.95):
#         self.num_cls = num_cls
#         self.momentum = momentum
#         self.weights = torch.ones(num_cls) / num_cls
        
#     def cal_weights(self, outputs, targets):
#         """计算基于难度的动态权重"""
#         probs = F.softmax(outputs, dim=1)
#         batch_size = outputs.size(0)
#         weights = torch.zeros(self.num_cls).to(outputs.device)
        
#         # 计算每个类别的预测难度
#         for i in range(self.num_cls):
#             mask = (targets == i)
#             if mask.sum() > 0:
#                 # 预测正确的概率越低，说明越难，权重越大
#                 correct_probs = probs[mask, i]
#                 weights[i] = 1 - correct_probs.mean()
        
#         # 使用动量更新
#         weights = F.softmax(weights, dim=0)
#         self.weights = self.momentum * self.weights + (1 - self.momentum) * weights.cpu()
#         return self.weights.to(outputs.device)

# class DistDW:
#     """Distribution-aware Dynamic Weighting"""
#     def __init__(self, num_cls=2, momentum=0.95):
#         self.num_cls = num_cls
#         self.momentum = momentum
#         self.weights = torch.ones(num_cls) / num_cls
        
#     def get_ema_weights(self, outputs):
#         """计算基于分布的动态权重"""
#         probs = F.softmax(outputs, dim=1)
#         cls_dist = probs.mean(dim=0)  # 计算类别分布
        
#         # 基于类别分布的不平衡程度计算权重
#         weights = 1 / (cls_dist + 1e-6)  # 避免除零
#         weights = F.softmax(weights, dim=0)
        
#         # 使用动量更新
#         self.weights = self.momentum * self.weights + (1 - self.momentum) * weights.cpu()
#         return self.weights.to(outputs.device)

        
    
class DiffDW:
    """Difficulty-aware dynamic class weighting.

    Keeps running statistics of how the per-class score returned by
    ``SoftDiceLoss`` changes between calls and turns them into class weights:
    classes whose score keeps dropping ("unlearning") relative to how much it
    rises ("learning") receive larger weights.
    """

    def __init__(self, num_cls=2, accumulate_iters=20, device='cuda:0'):
        self.device = device
        self.num_cls = num_cls
        self.accumulate_iters = accumulate_iters

        # Score from the previous call; starts just above zero so the first
        # ratio cur/last is finite.
        self.last_dice = torch.zeros(num_cls).float().to(self.device) + 1e-8
        self.dice_func = SoftDiceLoss(smooth=1e-8, do_bg=True)
        # Running accumulators for score increases / decreases.
        self.cls_learn = torch.zeros(num_cls).float().to(self.device)
        self.cls_unlearn = torch.zeros(num_cls).float().to(self.device)
        # Running average of (1 - score).
        self.dice_weight = torch.ones(num_cls).float().to(self.device)

    def init_weights(self):
        """Set every class weight to ``num_cls``.

        Returns:
            The initial weights as a NumPy array (the tensor copy is stored
            in ``self.weights``).
        """
        initial = np.ones(self.num_cls) * self.num_cls
        self.weights = torch.FloatTensor(initial).to(self.device)
        return initial

    def cal_weights(self, pred, label):
        """Compute difficulty-based class weights for the current batch.

        Args:
            pred: Predictions of shape [batch_size, num_classes].
            label: Labels of shape [batch_size] (or [batch_size, 1]).

        Returns:
            Tensor of class weights scaled so the largest equals ``num_cls``.
        """
        # scatter_ needs the labels as a column, i.e. [batch_size, 1].
        if label.dim() == 1:
            label = label.unsqueeze(1)

        # One-hot encode the hard predictions and the targets.
        pred_onehot = torch.zeros(pred.shape).to(self.device)
        hard_pred = torch.argmax(pred, dim=1, keepdim=True).long()
        pred_onehot.scatter_(1, hard_pred, 1)

        target_onehot = torch.zeros(pred.shape).to(self.device)
        target_onehot.scatter_(1, label, 1)

        cur_dice = self.dice_func(pred_onehot, target_onehot, is_training=False)
        prev_dice = self.last_dice
        delta_dice = cur_dice - prev_dice
        # NOTE(review): if cur_dice contains an exact 0 this is log(0) = -inf
        # and the products below become NaN/inf — confirm SoftDiceLoss with
        # smooth=1e-8 can never return 0.
        log_ratio = torch.log(cur_dice / prev_dice)
        step_learn = torch.where(delta_dice>0, delta_dice, 0) * log_ratio
        step_unlearn = torch.where(delta_dice<=0, delta_dice, 0) * log_ratio
        self.last_dice = cur_dice

        # All three running averages share the same momentum.
        momentum = (self.accumulate_iters-1)/self.accumulate_iters
        self.cls_learn = EMA(step_learn, self.cls_learn, momentum=momentum)
        self.cls_unlearn = EMA(step_unlearn, self.cls_unlearn, momentum=momentum)

        difficulty = (self.cls_unlearn + 1e-8) / (self.cls_learn + 1e-8)
        difficulty = torch.pow(difficulty, 1/5)
        self.dice_weight = EMA(1. - cur_dice, self.dice_weight, momentum=momentum)

        scaled = difficulty * self.dice_weight
        scaled = scaled / scaled.max()
        return scaled * self.num_cls




    
def fetch_labeled_data(batch_nodes, labels, feat_data):
    """Return the features unchanged together with the labels of the batch.

    Args:
        batch_nodes: Node indices of the batch.
        labels: Labels of all nodes; indexed with ``batch_nodes``.
        feat_data: Feature tensor; passed through and used for its device.

    Returns:
        ``(feat_data, batch_labels)`` with ``batch_labels`` an int64 tensor on
        ``feat_data.device``.
    """
    batch_labels = torch.tensor(labels[np.array(batch_nodes)])
    batch_labels = batch_labels.to(device=feat_data.device, dtype=torch.long)
    return feat_data, batch_labels

def fetch_unlabeled_data(batch):
    """Take the unlabeled inputs out of a batch and move them to the GPU.

    Args:
        batch: Mapping that holds the unlabeled inputs under ``'image_u'``.

    Returns:
        The result of calling ``.cuda()`` on ``batch['image_u']``.
    """
    unlabeled = batch['image_u']
    return unlabeled.cuda()

def get_current_consistency_weight(epoch, rampup_length=500):
    """Return the consistency weight for ``epoch`` on a ramp-up schedule.

    The value rises from ``exp(-5)`` at epoch 0 to 1.0 at ``rampup_length``
    following ``exp(-5 * (1 - t)^2)`` with ``t = epoch / rampup_length``
    clipped to [0, 1].

    Args:
        epoch: Current epoch.
        rampup_length: Number of epochs over which the weight ramps up;
            0 disables the ramp-up.

    Returns:
        The weight as a Python float.
    """
    if rampup_length != 0:
        clipped = np.clip(epoch, 0.0, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))
    return 1.0
    
def get_current_mu(epoch, max_epochs=500):
    """Return the value of mu for the given epoch.

    mu ramps up to its final value of 2 over ``max_epochs`` epochs using
    ``sigmoid_rampup``. (The former ``mu_rampup`` / ``consistency_rampup``
    locals were constants, which made the non-ramp-up branch unreachable;
    they have been folded away without changing the result.)

    Args:
        epoch: Current epoch.
        max_epochs: Length of the ramp-up in epochs.

    Returns:
        ``2 * sigmoid_rampup(epoch, max_epochs)``.
    """
    mu = 2  # final value of mu once the ramp-up has finished
    return mu * sigmoid_rampup(epoch, max_epochs)
    
def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from https://arxiv.org/abs/1610.02242.

    Args:
        current: Current step/epoch; clipped to ``[0, rampup_length]``.
        rampup_length: Length of the ramp-up; 0 means "no ramp-up".

    Returns:
        ``exp(-5 * (1 - current / rampup_length)^2)`` as a Python float, or
        1.0 when ``rampup_length`` is 0.
    """
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    shortfall = 1.0 - progress
    return float(np.exp(-5.0 * shortfall * shortfall))
    
def create_data_loaders(idx_train, y_train, batch_size, num_labeled=3304):
    """Split the training indices into a labeled head and an unlabeled tail.

    NOTE: a second ``create_data_loaders`` defined further down in this module
    rebinds the name at import time.

    Args:
        idx_train: Training indices (any sliceable sequence).
        y_train: Training labels; accepted for interface compatibility, unused.
        batch_size: Batch size; accepted for interface compatibility, unused.
        num_labeled: How many leading entries of ``idx_train`` count as
            labeled. Defaults to 3304, the value previously hard-coded for
            the Amazon dataset.

    Returns:
        Tuple ``(labeled_idx, unlabeled_idx)``.

    Raises:
        ValueError: If ``num_labeled`` is negative.
    """
    if num_labeled < 0:
        raise ValueError(f"num_labeled must be non-negative, got {num_labeled}")
    labeled_idx = idx_train[:num_labeled]
    unlabeled_idx = idx_train[num_labeled:]
    return labeled_idx, unlabeled_idx

def feature_consistency_loss(feat_A, feat_B, weights=(0.2, 0.4, 0.6, 0.8)):
    """Consistency loss between two feature tensors.

    Both tensors are L2-normalised along their last dimension and compared
    with a mean-squared error.

    Args:
        feat_A: Feature tensor produced by model A.
        feat_B: Feature tensor produced by model B.
        weights: Accepted for interface compatibility but not used by the
            computation. The default is now an immutable tuple instead of a
            shared mutable list.

    Returns:
        Scalar tensor holding the MSE between the normalised features.
    """
    # Promote vectors to a batch of one. Each input is checked on its own so
    # that a 1-D feat_B is handled even when feat_A is already 2-D.
    if feat_A.dim() == 1:
        feat_A = feat_A.unsqueeze(0)
    if feat_B.dim() == 1:
        feat_B = feat_B.unsqueeze(0)

    # Normalise so that only the direction of the features matters.
    feat_A = F.normalize(feat_A, p=2, dim=-1)
    feat_B = F.normalize(feat_B, p=2, dim=-1)

    return F.mse_loss(feat_A, feat_B)


def create_data_loaders(idx_train, y_train, batch_size, num_labeled=3304):
    """Split the training indices into a labeled head and an unlabeled tail.

    Args:
        idx_train: Training indices (any sliceable sequence).
        y_train: Training labels; accepted for interface compatibility, unused.
        batch_size: Batch size; accepted for interface compatibility, unused.
        num_labeled: How many leading entries of ``idx_train`` count as
            labeled. Defaults to 3304, the value previously hard-coded for
            the Amazon dataset.

    Returns:
        Tuple ``(labeled_idx, unlabeled_idx)``.

    Raises:
        ValueError: If ``num_labeled`` is negative.
    """
    if num_labeled < 0:
        raise ValueError(f"num_labeled must be non-negative, got {num_labeled}")
    labeled_idx = idx_train[:num_labeled]
    unlabeled_idx = idx_train[num_labeled:]
    return labeled_idx, unlabeled_idx

def get_batch_indices(sampled_idx_train, batch_size, batch, allow_repeat=True):
    """Select the node indices for one training batch.

    Args:
        sampled_idx_train: Training indices after sampling.
        batch_size: Number of nodes per batch.
        batch: Zero-based batch number (only used without repetition).
        allow_repeat: When True, draw ``batch_size`` indices uniformly with
            replacement so every batch is full; when False, return the
            ``batch``-th consecutive slice, which may be short or empty.

    Returns:
        The indices of the nodes in this batch.
    """
    if allow_repeat:
        # Sampling with replacement guarantees a full batch even when there
        # are fewer labeled nodes than batch_size.
        picks = np.random.choice(len(sampled_idx_train), batch_size, replace=True)
        return [sampled_idx_train[pos] for pos in picks]

    start = batch * batch_size
    stop = min(start + batch_size, len(sampled_idx_train))
    return sampled_idx_train[start:stop]

def update_ema_variables(model, ema_model, alpha=0.999):
    """Move the EMA model's parameters towards the live model, in place.

    Each EMA parameter becomes ``alpha * ema + (1 - alpha) * param``.
    Parameters are paired in the order ``parameters()`` yields them.

    Args:
        model: Model providing the current parameters.
        ema_model: Model whose parameters hold the moving average.
        alpha: Fraction of the old EMA value that is kept.
    """
    blend = 1 - alpha
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        averaged = ema_param.data
        averaged.mul_(alpha)
        averaged.add_(param.data, alpha=blend)

def calculate_g_mean(y_true, y_pred):
    """Geometric mean of sensitivity and specificity for binary labels 0/1.

    NOTE: a second ``calculate_g_mean`` defined further down in this module
    rebinds the name at import time.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        ``sqrt(sensitivity * specificity)``. A rate whose denominator is zero
        (no positives, or no negatives, in ``y_true``) counts as 0 instead of
        producing NaN.
    """
    # Fixing the label set keeps the matrix 2x2 even when only one class
    # occurs; without it cm[1, 1] raised IndexError in that case.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    TN, FP, FN, TP = cm.ravel()

    positives = TP + FN
    negatives = TN + FP
    sensitivity = TP / positives if positives != 0 else 0
    specificity = TN / negatives if negatives != 0 else 0

    g_mean = np.sqrt(sensitivity * specificity)
    return g_mean


def dict_to_edge_index(edge_dict):
    """Convert an adjacency mapping into a 2 x num_edges index tensor.

    Args:
        edge_dict: Mapping from a source node to an iterable of its targets.

    Returns:
        ``torch.LongTensor`` whose first row holds the source of every edge
        and whose second row holds the matching target, in iteration order.
    """
    pairs = [
        (src, dst)
        for src, neighbours in edge_dict.items()
        for dst in neighbours
    ]
    sources = [src for src, _ in pairs]
    targets = [dst for _, dst in pairs]
    return torch.LongTensor([sources, targets])



def numpy_array_to_edge_index(np_array):
    """Convert a square adjacency matrix into a 2 x num_edges index tensor.

    Args:
        np_array: Square 2-D array; every non-zero entry ``[i, j]`` is taken
            as an edge from ``i`` to ``j``.

    Returns:
        int64 tensor with the row indices in row 0 and the column indices in
        row 1.

    Raises:
        AssertionError: If the input is not a square 2-D matrix.
    """
    assert np_array.ndim == 2 and np_array.shape[0] == np_array.shape[1], "Input must be a square matrix."

    # np.nonzero yields (row_indices, col_indices); stacking them gives the
    # edge list in row-major order.
    edge_pairs = np.stack(np.nonzero(np_array))
    return torch.from_numpy(edge_pairs).long()

def load_data_mask(data, k=2, prefix=''):
    """Load graph, feature, and label data, plus masked graph views for Amazon.

    Args:
        data: Dataset name: ``'yelp'``, ``'amazon'`` or ``'CCFD'``.
        k: Number of matrix pickles (``<prefix><i>.pkl`` for i = 1..k) to
            load per relation.
        prefix: Directory that contains the data files.

    Returns:
        For ``'yelp'``: ``(edge_indexs, feat_data, labels)``.
        For ``'amazon'``: ``(original_edge_indexs, front_edge_indexs,
        back_edge_indexs, feat_data, labels)``, where the front/back views are
        read from the ``front_masked`` / ``back_masked`` sub-directories of
        ``prefix``.
        Every ``*_edge_indexs`` value is a list with one
        ``[relation_edge_index, [tree_edge_index_1, ..., tree_edge_index_k]]``
        entry per relation.

    Raises:
        AssertionError: For ``'CCFD'``, which is not publicly available.
        ValueError: For any other unknown dataset name (previously the
            function silently returned ``None``).
    """

    def _unpickle(path):
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted data files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _load_relation(directory, key):
        """Return ``[edge_index, trees]`` for relation ``key`` in ``directory``."""
        relation = dict_to_edge_index(
            _unpickle(os.path.join(directory, filelist[key]))
        )
        tree_prefix = os.path.join(directory, file_matrix_prefix[key])
        trees = [
            numpy_array_to_edge_index(
                _unpickle('{}{}.pkl'.format(tree_prefix, order))
            )
            for order in range(1, k + 1)
        ]
        return [relation, trees]

    if data == 'yelp':
        data_file = loadmat(os.path.join(prefix, 'YelpChi.mat'))
        labels = data_file['label'].flatten()
        feat_data = data_file['features'].todense().A

        edge_indexs = [
            _load_relation(prefix, key)
            for key in ('yelp_rur', 'yelp_rtr', 'yelp_rsr')
        ]
        return edge_indexs, feat_data, labels

    if data == 'amazon':
        data_file = loadmat(os.path.join(prefix, 'Amazon.mat'))
        labels = data_file['label'].flatten()
        feat_data = data_file['features'].todense().A

        relation_keys = ('amz_upu', 'amz_usu', 'amz_uvu')

        # 1. Original (unmasked) graphs.
        original_edge_indexs = [
            _load_relation(prefix, key) for key in relation_keys
        ]

        # 2./3. Front- and back-masked graphs.
        # The tree paths are built from the bare file-name prefix. The former
        # code joined the already-prefixed path a second time, which doubled
        # a relative prefix and, for an absolute prefix, silently loaded the
        # unmasked matrices instead of the masked ones.
        front_dir = os.path.join(prefix, 'front_masked')
        front_edge_indexs = [
            _load_relation(front_dir, key) for key in relation_keys
        ]

        back_dir = os.path.join(prefix, 'back_masked')
        back_edge_indexs = [
            _load_relation(back_dir, key) for key in relation_keys
        ]

        return (original_edge_indexs, front_edge_indexs, back_edge_indexs,
                feat_data, labels)

    if data == 'CCFD':
        # Raised explicitly (same exception type as the former `assert False`)
        # so the guard also holds under `python -O`. The loader code that
        # followed it was unreachable and has been dropped.
        raise AssertionError(
            'CCFD dataset is secret, please contact the author for the dataset.'
        )

    raise ValueError(f"Unknown dataset name: {data!r}")
        


def Visualization(labels, embedding, prefix):
    """Plot a 2-D t-SNE projection of node embeddings, coloured by label.

    The nodes are first class-balanced by undersampling the negatives, then
    5000 of them are drawn with replacement, projected with t-SNE and
    rescaled to the unit square. The figure is written to
    ``<prefix>/HOGRL.png`` and also shown.

    Args:
        labels: Array of node labels (0 / 1), indexable by an index array.
        embedding: Array of node embeddings aligned with ``labels``.
        prefix: Directory the image is saved into.
    """
    every_node = list(range(len(labels)))
    pos_nodes, neg_nodes = pos_neg_split(every_node, labels)
    balanced_nodes = undersample(pos_nodes, neg_nodes, scale=1)

    tsne = TSNE(n_components=2, random_state=43)
    balanced_nodes = np.array(balanced_nodes)
    chosen = np.random.choice(balanced_nodes, size=5000, replace=True)
    ps = embedding[chosen]
    ls = labels[chosen]

    projected = tsne.fit_transform(ps)

    # Rescale both axes to [0, 1] so the fixed axis limits below fit.
    X_scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(projected)
    print(X_scaled.shape)

    plt.figure(figsize=(8, 8))

    for cls, colour, legend in ((0, '#14517C', 'Label 0'), (1, '#FA7F6F', 'Label 1')):
        members = ls == cls
        plt.scatter(X_scaled[members, 0], X_scaled[members, 1], c=colour, label=legend, s=3)

    # Remove the frame and the ticks: only the point cloud should be visible.
    ax = plt.gca()
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)

    plt.xticks([])
    plt.yticks([])

    plt.xlim(0, 1)
    plt.ylim(0, 1)
    filepath = os.path.join(prefix, 'HOGRL.png')
    plt.savefig(filepath)
    plt.show()
    
def normalize(mx):
    """
        Row-normalize sparse matrix
        Code from https://github.com/williamleif/graphsage-simple/

        Every row is divided by (its sum + 0.01); the small offset keeps
        all-zero rows from causing a division by zero.
    """
    row_totals = np.array(mx.sum(1)) + 0.01
    inv_totals = np.power(row_totals, -1).flatten()
    # Any infinite reciprocal is replaced by 0.
    inv_totals = np.where(np.isinf(inv_totals), 0., inv_totals)
    return sp.diags(inv_totals).dot(mx)



def pos_neg_split(nodes, labels):
    """
    Find positive and negative nodes given a list of nodes and their labels

    ``labels[i]`` is taken to be the label of ``nodes[i]``. Nodes whose label
    equals 1 are positive; all others, including nodes beyond the end of
    ``labels``, are negative. The input list is not modified.

    Runs in linear time. The previous implementation called ``list.remove``
    once per positive node, which was quadratic. For node lists without
    duplicates the result is identical.

    :param nodes: a list of nodes
    :param labels: a list of node labels
    :returns: the split positive and negative nodes, each in original order
    :raises IndexError: if a positive label sits at a position beyond the end
        of ``nodes``
    """
    positive_positions = [pos for pos, label in enumerate(labels) if label == 1]
    pos_nodes = [nodes[pos] for pos in positive_positions]

    # Set membership keeps the negative pass O(n).
    positive_lookup = set(positive_positions)
    neg_nodes = [
        node for pos, node in enumerate(nodes) if pos not in positive_lookup
    ]

    return pos_nodes, neg_nodes


def undersample(pos_nodes, neg_nodes, scale=1):
    """
    Under-sample the negative nodes

    Draws ``int(len(pos_nodes) * scale)`` negatives without replacement and
    appends them to the positives.

    :param pos_nodes: a list of positive nodes
    :param neg_nodes: a list negative nodes
    :param scale: the under-sampling scale
    :return: a list of under-sampled batch nodes
    """
    n_keep = int(len(pos_nodes) * scale)
    neg_pool = cp.deepcopy(neg_nodes)
    sampled_neg = rd.sample(neg_pool, k=n_keep)
    return pos_nodes + sampled_neg

def calculate_g_mean(y_true, y_pred):
    """Geometric mean of the per-class recalls.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        The n-th root of the product of the n per-class recalls taken from
        the confusion matrix; a class with no true samples contributes 0.
    """
    cm = confusion_matrix(y_true, y_pred)
    n_classes = len(cm)
    recalls = []
    for cls, row in enumerate(cm):
        support = row.sum()  # TP + FN: all samples that truly belong to cls
        hits = cm[cls, cls]
        recalls.append(hits / support if support != 0 else 0)
    return np.prod(recalls) ** (1 / n_classes)

class MaskGenerator:
    """Callable that draws a random boolean node mask of a fixed ratio."""

    def __init__(self, mask_ratio=0.3):
        self.mask_ratio = mask_ratio

    def __call__(self, num_nodes):
        """Draw a fresh random mask.

        Args:
            num_nodes: Total number of nodes in the graph.

        Returns:
            Boolean tensor of shape ``(num_nodes,)`` in which exactly
            ``int(num_nodes * mask_ratio)`` entries are True (= masked).
        """
        n_masked = int(num_nodes * self.mask_ratio)
        chosen = torch.randperm(num_nodes)[:n_masked]
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[chosen] = True
        return mask

def get_unlabeled_batch_indices(unlabeled_idx, num_batches, batch_idx, shuffle=True):
    """Pick the unlabeled node indices that belong to one batch.

    The indices are divided into ``num_batches`` nearly equal consecutive
    chunks; the first ``len(unlabeled_idx) % num_batches`` chunks receive one
    extra element.

    Args:
        unlabeled_idx: Indices of the unlabeled nodes.
        num_batches: Total number of batches.
        batch_idx: Index of the requested batch.
        shuffle: Shuffle a copy of the indices before slicing.

    Returns:
        List of node indices for batch ``batch_idx``.
    """
    # np.array copies, so shuffling never touches the caller's data.
    pool = np.array(unlabeled_idx)
    if shuffle:
        # NOTE(review): the pool is reshuffled on every call, so batches
        # obtained from separate calls may overlap — confirm this is intended.
        np.random.shuffle(pool)

    base_size, extra = divmod(len(pool), num_batches)

    start = batch_idx * base_size + min(batch_idx, extra)
    stop = start + base_size + int(batch_idx < extra)

    return pool[start:stop].tolist()

def load_data(data, k=2, prefix=''):
    """
    Load graph, feature, and label given dataset name

    Args:
        data: Dataset name: ``'yelp'``, ``'amazon'``, ``'ffsd'`` or ``'CCFD'``.
        k: Number of matrix pickles (``<prefix><i>.pkl`` for i = 1..k) to
            load per relation.
        prefix: Directory that contains the data files.

    Returns:
        ``(edge_indexs, feat_data, labels)`` where ``edge_indexs`` is a list
        with one ``[relation_edge_index, [tree_edge_index, ...]]`` entry per
        relation (three for yelp/amazon, one for ffsd).

    Raises:
        AssertionError: For ``'CCFD'``, which is not publicly available.
        ValueError: For any other unknown dataset name (previously the
            function silently returned ``None``).
    """

    def _unpickle(path):
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted data files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _load_relation(key):
        """Return ``[edge_index, trees]`` for one entry of ``filelist``."""
        relation = dict_to_edge_index(
            _unpickle(os.path.join(prefix, filelist[key]))
        )
        tree_prefix = os.path.join(prefix, file_matrix_prefix[key])
        trees = [
            numpy_array_to_edge_index(
                _unpickle('{}{}.pkl'.format(tree_prefix, order))
            )
            for order in range(1, k + 1)
        ]
        return [relation, trees]

    def _load_mat(file_name):
        """Return ``(feat_data, labels)`` read from a .mat file under prefix."""
        data_file = loadmat(os.path.join(prefix, file_name))
        labels = data_file['label'].flatten()
        feat_data = data_file['features'].todense().A
        return feat_data, labels

    if data == 'yelp':
        feat_data, labels = _load_mat('YelpChi.mat')
        edge_indexs = [
            _load_relation(key) for key in ('yelp_rur', 'yelp_rtr', 'yelp_rsr')
        ]
        return edge_indexs, feat_data, labels

    if data == 'amazon':
        feat_data, labels = _load_mat('Amazon.mat')
        # load the preprocessed adj_lists
        edge_indexs = [
            _load_relation(key) for key in ('amz_upu', 'amz_usu', 'amz_uvu')
        ]
        return edge_indexs, feat_data, labels

    if data == 'ffsd':
        # --- node features and labels --------------------------------------
        try:
            print("从S-FFSD.mat加载数据...")
            feat_data, labels = _load_mat('S-FFSD.mat')
            print(f"加载了特征矩阵，形状: {feat_data.shape}")
        except Exception as e:
            print(f"加载.mat文件出错: {e}")
            # The .mat file could not be read: fall back to the CSV file.
            import pandas as pd
            print("尝试从CSV加载...")
            df = pd.read_csv(os.path.join(prefix, 'S-FFSD.csv'))

            # The 'Labels' column holds the labels; every other column
            # becomes a feature.
            labels = df['Labels'].values
            features = df.drop(columns=['Labels'])

            # Encode non-numeric feature columns as integer codes.
            from sklearn.preprocessing import LabelEncoder
            for col in features.select_dtypes(include=['object']).columns:
                le = LabelEncoder()
                features[col] = le.fit_transform(features[col].astype(str))

            feat_data = features.values
            print(f"从CSV加载了特征矩阵，形状: {feat_data.shape}")

        # --- transaction graph: edges from Source to Target -----------------
        adj_path = os.path.join(prefix, 'sffsd_adjlists.pickle')
        try:
            print("加载邻接列表...")
            # Prefer the preprocessed adjacency list when it exists.
            relation = _unpickle(adj_path)
            print("成功加载sffsd_adjlists.pickle")
        except Exception:
            # `except Exception` instead of a bare `except` so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print("未找到预处理的邻接列表，尝试从CSV构建...")
            import zlib
            import pandas as pd
            df = pd.read_csv(os.path.join(prefix, 'S-FFSD.csv'))

            def _node_id(value):
                """Map a Source/Target cell to an integer node id."""
                # NumPy scalars are included: pandas often hands back
                # np.int64, which is not an instance of the builtin int.
                if isinstance(value, (int, float, np.integer, np.floating)):
                    return int(value)
                # The builtin hash() of a str is salted per process, so the
                # ids differed between runs and the pickle written below was
                # not reproducible. CRC32 is stable across runs.
                # NOTE(review): the modulo can map distinct names to the same
                # id — confirm collisions are acceptable.
                return zlib.crc32(str(value).encode('utf-8')) % 100000

            # Build the source -> [targets] mapping.
            relation = {}
            for _, row in df.iterrows():
                source = _node_id(row['Source'])
                target = _node_id(row['Target'])
                relation.setdefault(source, []).append(target)

            # Cache the adjacency list for future runs.
            with open(adj_path, 'wb') as file:
                pickle.dump(relation, file)
            print(f"已从CSV构建并保存邻接列表，包含 {len(relation)} 个源节点")

        # Convert to edge-index format.
        relation = dict_to_edge_index(relation)
        print(f"边索引形状: {relation.shape}")

        # --- tree-structure matrices ----------------------------------------
        # (original note: the data directory holds at most 7 matrix files)
        relation_tree = []
        for i in range(1, k+1):
            try:
                file_name = os.path.join(prefix, f'sffsd_matrix_{i}.pkl')
                print(f"尝试加载矩阵 {i}...")
                tree = _unpickle(file_name)
                relation_tree.append(numpy_array_to_edge_index(tree))
                print(f"矩阵 {i} 加载完成")
            except Exception as e:
                print(f"加载矩阵 {i} 失败: {e}")
                # Best effort: skip matrices that cannot be loaded.
                continue

        print(f"成功加载了 {len(relation_tree)} 个树结构矩阵")
        return [[relation, relation_tree]], feat_data, labels

    if data == 'CCFD':
        # Raised explicitly (same exception type as the former `assert False`)
        # so the guard also holds under `python -O`. The loader code that
        # followed it was unreachable and has been dropped.
        raise AssertionError(
            'CCFD dataset is secret, please contact the author for the dataset.'
        )

    raise ValueError(f"Unknown dataset name: {data!r}")

class _NodeIndexDataset(torch.utils.data.Dataset):
    """Map-style dataset over node indices, optionally paired with labels.

    The class lives at module level (rather than inside ``make_loader``) so
    that instances can be pickled. ``DataLoader`` must pickle the dataset to
    hand it to worker processes when workers are started with ``spawn``.
    """

    def __init__(self, indices, labels, unlabeled=False):
        self.indices = indices
        self.labels = labels
        self.unlabeled = unlabeled

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        node_idx = self.indices[idx]
        if self.unlabeled:
            return node_idx
        return node_idx, self.labels[node_idx]

    def get_class_distribution(self):
        """Return per-class frequencies of the selected nodes, or None if unlabeled."""
        if self.unlabeled:
            return None
        class_counts = np.bincount(self.labels[self.indices])
        return class_counts / len(self.indices)


def make_loader(indices, labels, batch_size, is_training=True, unlabeled=False):
    """Create a ``DataLoader`` that yields batches of node indices.

    Args:
        indices: Sequence of node indices to iterate over.
        labels: Label container indexed by node index (ignored when
            ``unlabeled`` is true).
        batch_size: Number of nodes per batch.
        is_training: When true the loader shuffles and uses two worker
            processes; otherwise it iterates in order in the main process.
        unlabeled: When true each item is just the node index; otherwise
            each item is the pair ``(node_index, labels[node_index])``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
    """
    dst = _NodeIndexDataset(indices, labels, unlabeled=unlabeled)

    if is_training:
        return torch.utils.data.DataLoader(
            dst,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=2
        )
    return torch.utils.data.DataLoader(
        dst,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True
    )
    
def consistency_kl_loss(feat_A, feat_B, kl_distance, T=2.0, device=None):
    """Symmetric, temperature-scaled KL consistency loss between two feature sets.

    Each feature matrix is L2-normalised along dim 1, divided by the
    temperature and pushed through a softmax along dim 1. The two resulting
    distributions are compared in both directions with ``kl_distance`` and
    the average is multiplied by ``T ** 2``.

    Args:
        feat_A, feat_B: Feature tensors from the two models/views.
        kl_distance: Callable taking ``(log_probs, target_probs)``, e.g. an
            ``nn.KLDivLoss`` instance.
        T: Softmax temperature.
        device: Device on which the temperature/epsilon scalars are created.

    Returns:
        The loss as returned by ``kl_distance`` (averaged over both
        directions) scaled by ``T ** 2``.
    """
    temperature = torch.tensor(T, device=device)
    epsilon = torch.tensor(1e-8, device=device)

    def _as_distribution(features):
        # Unit-normalise each row, then soften with the temperature.
        unit_rows = F.normalize(features, p=2, dim=1)
        return F.softmax(unit_rows / temperature, dim=1)

    dist_a = _as_distribution(feat_A)
    dist_b = _as_distribution(feat_B)

    # epsilon keeps log() finite if a probability underflows to zero.
    kl_a_vs_b = kl_distance(torch.log(dist_a + epsilon), dist_b)
    kl_b_vs_a = kl_distance(torch.log(dist_b + epsilon), dist_a)
    symmetric_kl = (kl_a_vs_b + kl_b_vs_a) / 2

    return symmetric_kl * (temperature ** 2)

def degree_drop_weights(edge_index):
    """Compute degree-based drop weights, one per edge of ``edge_index``.

    Node degrees are counted on the undirected version of the graph. Each
    edge is scored by the log-degree of its target node (row 1), and the
    score is rescaled as ``(max - s) / (max - mean)``, so edges pointing at
    the highest-degree node get weight 0.

    Note: if every edge has the same score the denominator is zero and the
    result is NaN.
    """
    symmetric_edges = to_undirected(edge_index)
    node_degree = degree(symmetric_edges[1])
    log_target_degree = torch.log(node_degree[edge_index[1]].to(torch.float32))
    peak = log_target_degree.max()
    return (peak - log_target_degree) / (peak - log_target_degree.mean())

def pr_drop_weights(edge_index, aggr='sink', k=10):
    """Compute PageRank-based drop weights, one per edge of ``edge_index``.

    Args:
        edge_index: Edge index tensor; row 0 holds sources, row 1 targets.
        aggr: Which endpoint's log-PageRank scores the edge: ``'source'``
            uses row 0, ``'mean'`` averages both endpoints, and ``'sink'``
            (as well as any unrecognised value) uses row 1.
        k: Number of PageRank iterations passed to ``compute_pr``.

    Returns:
        Per-edge weights ``(max - s) / (max - mean)`` of the chosen score.
    """
    rank = compute_pr(edge_index, k=k)
    log_rank_src = torch.log(rank[edge_index[0]].to(torch.float32))
    log_rank_dst = torch.log(rank[edge_index[1]].to(torch.float32))

    if aggr == 'source':
        score = log_rank_src
    elif aggr == 'mean':
        score = (log_rank_dst + log_rank_src) * 0.5
    else:
        # 'sink' and every unknown mode fall back to the target node's score.
        score = log_rank_dst

    top = score.max()
    return (top - score) / (top - score.mean())

def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    """Run ``k`` PageRank-style propagation steps over ``edge_index``.

    Starting from an all-ones score vector ``x``, each step sends
    ``x[src] / out_degree(src)`` along every edge, sums the messages at the
    target nodes and updates ``x = (1 - damp) * x + damp * aggregated``.
    The scores are not normalised.

    The aggregation is vectorised with ``index_add_``; the previous
    implementation looped over edges in Python (and reused the outer loop
    variable), which was prohibitively slow on large graphs.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``; row 0 holds source
            nodes and row 1 target nodes. Must be non-empty.
        damp: Damping factor.
        k: Number of propagation steps.

    Returns:
        Float32 tensor of length ``edge_index.max() + 1`` on the same device
        as ``edge_index``.
    """
    src, dst = edge_index[0], edge_index[1]
    num_nodes = edge_index.max().item() + 1
    x = torch.ones((num_nodes,), dtype=torch.float32, device=edge_index.device)

    # Out-degree of every node, then gathered per edge once (loop-invariant).
    deg_out = torch.zeros_like(x).index_add_(
        0, src, torch.ones_like(src, dtype=x.dtype)
    )
    src_out_degree = deg_out[src]

    for _ in range(k):
        edge_msg = x[src] / src_out_degree
        # Sum incoming messages per target node; duplicates accumulate.
        agg_msg = torch.zeros_like(x).index_add_(0, dst, edge_msg)
        x = (1 - damp) * x + damp * agg_msg

    return x

def get_augmented_view(edge_indexs, feat_data, aug_type, drop_rate=0.2):
    """Build one augmented view of a multi-relation graph.

    Args:
        edge_indexs: List with one entry per relation; each entry is indexed
            as ``entry[0]`` (main edge index) and ``entry[1]`` (list of tree
            edge indices).
        feat_data: Node feature matrix (rows are nodes).
        aug_type: One of ``'edge_drop'``, ``'feat_drop'``, ``'degree'``,
            ``'pr'`` or ``'weighted_feat'``.
        drop_rate: Drop probability / strength.

    Returns:
        A 2-tuple whose ORDER DEPENDS ON THE AUGMENTATION FAMILY:

        * feature augmentations (``'feat_drop'``, ``'weighted_feat'``) return
          ``(edge_indexs, augmented_features)``;
        * edge augmentations (``'edge_drop'``, ``'degree'``, ``'pr'``) return
          ``(feat_data, augmented_edge_indexs)``.

        NOTE(review): the asymmetric order is kept because callers may rely
        on it; confirm before unifying.

    Raises:
        ValueError: For an unsupported ``aug_type`` (raised while processing
            the first relation, so an empty ``edge_indexs`` does not raise).
    """
    if aug_type == 'feat_drop':
        # Zero out whole feature columns chosen uniformly at random.
        feat_mask = torch.rand(feat_data.size(1)) > drop_rate
        feat_aug = feat_data.clone()
        feat_aug[:, ~feat_mask] = 0
        return edge_indexs, feat_aug

    if aug_type == 'weighted_feat':
        # In-degree from the first relation's main graph. num_nodes pins the
        # vector length to the number of feature rows; without it the vector
        # is too short whenever the highest-indexed nodes have no incoming
        # edge, and the matmul in feature_drop_weights fails on shape.
        node_deg = degree(edge_indexs[0][0][1], num_nodes=feat_data.size(0))
        feat_weights = feature_drop_weights(feat_data, node_deg)
        feat_aug = drop_feature_weighted(feat_data, feat_weights, drop_rate)
        return edge_indexs, feat_aug

    # Edge-level augmentations: process every relation independently.
    augmented_edge_indexs = []

    for edge_index in edge_indexs:
        main_edges, tree_edges = edge_index[0], edge_index[1]

        if aug_type == 'edge_drop':
            # Uniform random edge removal (main graph first, then each tree).
            edge_mask = torch.rand(main_edges.size(1)) > drop_rate
            edge_index_main = main_edges[:, edge_mask]
            edge_index_trees = [
                tree_edge[:, torch.rand(tree_edge.size(1)) > drop_rate]
                for tree_edge in tree_edges
            ]

        elif aug_type == 'degree':
            # Degree-weighted edge removal.
            drop_weights = degree_drop_weights(main_edges)
            edge_index_main = drop_edge_weighted(main_edges, drop_weights, p=drop_rate)
            edge_index_trees = [
                drop_edge_weighted(tree_edge, degree_drop_weights(tree_edge), p=drop_rate)
                for tree_edge in tree_edges
            ]

        elif aug_type == 'pr':
            # PageRank-weighted edge removal.
            drop_weights = pr_drop_weights(main_edges, aggr='sink', k=10)
            edge_index_main = drop_edge_weighted(main_edges, drop_weights, p=drop_rate)
            edge_index_trees = [
                drop_edge_weighted(
                    tree_edge, pr_drop_weights(tree_edge, aggr='sink', k=10), p=drop_rate
                )
                for tree_edge in tree_edges
            ]

        else:
            raise ValueError(f"不支持的增强类型: {aug_type}")

        augmented_edge_indexs.append([edge_index_main, edge_index_trees])

    return feat_data, augmented_edge_indexs

def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 0.7):
    """Randomly remove edges with per-edge drop probabilities.

    Args:
        edge_index: Edge index tensor whose columns are edges.
        edge_weights: One weight per edge; larger means more likely dropped.
        p: Target average drop probability (weights are rescaled so that
            their mean equals ``p`` before capping).
        threshold: Upper cap applied to every drop probability.

    Returns:
        ``edge_index`` restricted to the columns that survived sampling.
    """
    drop_prob = edge_weights / edge_weights.mean() * p
    # Entries that are not strictly below the cap are replaced by the cap.
    cap = torch.ones_like(drop_prob) * threshold
    drop_prob = torch.where(drop_prob < threshold, drop_prob, cap)
    keep = torch.bernoulli(1. - drop_prob).to(torch.bool)
    return edge_index[:, keep]

def get_A_bounds(edge_index, drop_rate, local_changes, dataset):
    """Compute dense element-wise upper and lower bounds for the adjacency.

    Args:
        edge_index: Edge index of the graph.
        drop_rate: Edge drop rate. Currently unused; it only appeared in the
            (disabled) on-disk cache file name.
        local_changes: Per-node budget; rounded to an integer and used to cap
            the number of edges assumed removable at each node.
        dataset: Dataset name. Currently unused (cache file name only).

    Returns:
        ``(A_upper, A_lower)`` as dense float32 NumPy arrays of shape
        ``(N, N)``. Both are materialised densely, so memory grows
        quadratically with the number of nodes.
    """
    # Degrees are counted on the undirected graph; the sparsity pattern comes
    # from ``edge_index`` exactly as given.
    undirected_deg = degree(to_undirected(edge_index)[1]).cpu().numpy()
    adj = to_scipy_sparse_matrix(edge_index).tocsr()
    adj_with_loops = adj + sp.eye(adj.shape[0])

    # Upper bound: 1 / sqrt((d_i + 1 - m_i) * (d_j + 1 - m_j)) wherever
    # A + I is non-zero, with m the capped number of removable edges.
    deg_tilde = undirected_deg + 1
    removable = np.maximum(deg_tilde.astype("int") - 2, 0)
    removable = np.minimum(removable, np.round(local_changes).astype("int"))
    inv_sqrt_reduced = 1 / np.sqrt(deg_tilde - removable)
    pairwise = inv_sqrt_reduced * inv_sqrt_reduced[:, None]
    support = adj_with_loops.toarray() > 0
    A_upper = np.float32(np.where(support, pairwise, np.zeros_like(pairwise)))

    # Lower bound: keep only the diagonal of the GCN-normalised adjacency;
    # every off-diagonal entry is zero.
    norm_edge_index, norm_weight = gcn_norm(edge_index, num_nodes=adj.shape[0])
    dense_norm = to_dense_adj(norm_edge_index, edge_attr=norm_weight)[0].cpu().numpy()
    A_lower = np.zeros_like(dense_norm)
    diagonal = np.diag_indices_from(A_lower)
    A_lower[diagonal] = np.diag(dense_norm)
    A_lower = np.float32(A_lower)

    return A_upper, A_lower

def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed applied to ``random``, ``numpy.random``, ``torch`` (CPU
            and all CUDA devices) and exported as ``PYTHONHASHSEED``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Disable the cuDNN autotuner and force deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_drop_weights(edge_indexs, scheme='degree'):
    """Compute edge drop weights for every relation of a multi-layer graph.

    Args:
        edge_indexs: List with one entry per relation; ``entry[0]`` is the
            main edge index and ``entry[1]`` an iterable of tree edge indices.
        scheme: ``'degree'`` uses ``degree_drop_weights``; ``'pr'`` uses
            ``pr_drop_weights`` with ``aggr='sink', k=10``; any other value
            yields ``None`` in place of each weight tensor.

    Returns:
        A list ``[[main_weights, [tree_weights, ...]], ...]`` aligned with
        ``edge_indexs``.
    """
    def _weights_for(edges):
        # Looked up lazily so unknown schemes never touch the helpers.
        if scheme == 'degree':
            return degree_drop_weights(edges)
        if scheme == 'pr':
            return pr_drop_weights(edges, aggr='sink', k=10)
        return None

    return [
        [_weights_for(relation[0]), [_weights_for(tree) for tree in relation[1]]]
        for relation in edge_indexs
    ]

def get_crown_weights(l1, u1, l2, u2, alpha, gcn_weights, Wcl):
    """Back-propagate linear relaxation weights through two GCN layers.

    Args:
        l1, u1: Lower/upper pre-activation bounds of the first layer.
        l2, u2: Lower/upper pre-activation bounds of the second layer.
        alpha: Activation parameter forwarded to ``get_alpha_beta``.
        gcn_weights: Sequence ``(W1, b1, W2, b2)`` of layer weights/biases.
        Wcl: Contrastive-learning weight matrix; it is projected to
            ``l2.shape[1]`` columns before use.

    Shapes noted by the original author (not checked here): bounds
    ``(256, 64)``, ``W1 (25, 64)``, ``W2 (64, 64)``, biases ``(64,)``,
    ``Wcl (256, 192)``.

    Returns:
        ``(W_tilde_1, b_tilde_1, W_tilde_2, b_tilde_2)``.
    """
    # NOTE(review): a new, randomly initialised Linear layer is created on
    # every call, so the projection differs between calls unless the RNG is
    # re-seeded. Kept as-is because callers may depend on it -- confirm.
    projection = nn.Linear(Wcl.shape[1], l2.shape[1], device=Wcl.device)
    Wcl = projection(Wcl)

    alpha_2_L, alpha_2_U, beta_2_L, beta_2_U = get_alpha_beta(l2, u2, alpha)
    alpha_1_L, alpha_1_U, beta_1_L, beta_1_U = get_alpha_beta(l1, u1, alpha)

    W1_tensor, b1_tensor, W2_tensor, b2_tensor = gcn_weights

    # Layer 2: choose the lower or upper relaxation by the sign of the
    # incoming weight.
    lambda_2 = torch.where(Wcl >= 0, alpha_2_L, alpha_2_U)
    Delta_2 = torch.where(Wcl >= 0, beta_2_L, beta_2_U)
    Lambda_2 = lambda_2 * Wcl

    W_tilde_2 = Lambda_2 @ W2_tensor.T
    # Row-wise dot product, i.e. diag(Lambda_2 @ (Delta_2 + b2).T) computed
    # without materialising the N x N product just to read its diagonal.
    b_tilde_2 = (Lambda_2 * (Delta_2 + b2_tensor)).sum(dim=1)

    # Layer 1: same construction, driven by the sign of W_tilde_2.
    lambda_1 = torch.where(W_tilde_2 >= 0, alpha_1_L, alpha_1_U)
    Delta_1 = torch.where(W_tilde_2 >= 0, beta_1_L, beta_1_U)
    Lambda_1 = lambda_1 * W_tilde_2

    W_tilde_1 = Lambda_1 @ W1_tensor.T
    b_tilde_1 = (Lambda_1 * (Delta_1 + b1_tensor)).sum(dim=1)

    return W_tilde_1, b_tilde_1, W_tilde_2, b_tilde_2

def get_alpha_beta(l, u, alpha):
    """Slopes and offsets of linear bounds for a leaky activation.

    For every element, given pre-activation bounds ``l <= u``:

    * ``l >= 0``: slope 1 for both bounds;
    * ``u <= 0``: slope ``alpha`` for both bounds (this case wins when
      ``l >= 0`` and ``u <= 0`` hold together);
    * otherwise: both slopes are ``(u - alpha * l) / (u - l)`` and the upper
      offset is ``(alpha - 1) * u * l / (u - alpha * l)``.

    Returns:
        ``(alpha_L, alpha_U, beta_L, beta_U)``, four tensors shaped like
        ``l`` on ``l``'s device. ``beta_L`` is always zero.
    """
    always_pos = l >= 0
    always_neg = u <= 0
    crossing = ~(always_pos | always_neg)

    l_cross, u_cross = l[crossing], u[crossing]
    chord_numer = u_cross - alpha * l_cross

    slope = torch.zeros(l.shape, device=l.device)
    slope[always_pos] = 1
    slope[always_neg] = alpha
    slope[crossing] = chord_numer / (u_cross - l_cross)

    beta_L = torch.zeros(l.shape, device=l.device)
    beta_U = torch.zeros(l.shape, device=l.device)
    beta_U[crossing] = (alpha - 1) * u_cross * l_cross / chord_numer

    # Lower and upper slopes coincide but are returned as separate tensors.
    alpha_L = slope
    alpha_U = torch.clone(slope)
    return alpha_L, alpha_U, beta_L, beta_U

def timer(func):
    """Decorator that prints how long each call to ``func`` took.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than the
    wall clock, so the measurement is not disturbed by system clock
    adjustments. The wrapped function's result is returned unchanged; if
    ``func`` raises, nothing is printed and the exception propagates.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f'{func.__name__} 运行时间: {elapsed:.4f}秒')
        return result
    return wrapper


def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    """Zero out feature entries with per-feature drop probabilities.

    Args:
        x: Feature matrix (rows are samples); it is not modified.
        w: One weight per feature column; larger means more likely dropped.
        p: Target average drop probability (weights are rescaled so that
            their mean equals ``p`` before capping).
        threshold: Upper cap applied to every drop probability.

    Returns:
        A copy of ``x`` in which independently sampled entries are set to 0.
    """
    scaled = w / w.mean() * p
    # Entries that are not strictly below the cap are replaced by the cap.
    capped = torch.where(scaled < threshold, scaled, torch.ones_like(scaled) * threshold)

    # Tile the per-feature probabilities so every row gets its own draw.
    num_rows = x.size(0)
    per_entry_prob = capped.repeat(num_rows).view(num_rows, -1)
    zero_out = torch.bernoulli(per_entry_prob).to(torch.bool)

    dropped = x.clone()
    dropped[zero_out] = 0.
    return dropped

def feature_drop_weights(x, node_c):
    """Compute per-feature drop weights from node centrality.

    Each feature column is scored by the log of the summed centrality
    ``node_c`` of the nodes where that feature is non-zero; the score is then
    rescaled as ``(max - s) / (max - mean)``.

    Note: a feature that is zero for every node gets a score of ``log(0)``,
    which makes the result non-finite.

    Args:
        x: Feature matrix (rows are nodes).
        node_c: Centrality value per node, one entry per row of ``x``.
    """
    presence = x.to(torch.bool).to(torch.float32)
    log_mass = (presence.t() @ node_c).log()
    top = log_mass.max()
    return (top - log_mass) / (top - log_mass.mean())

def compute_consistency_loss(h1, h2, out1, out2, batch_nodes_tensor, temperature=0.5):
    """Feature-level plus prediction-level consistency between two views.

    The feature term is the MSE between ``h1`` and ``h2`` on the batch nodes.
    The prediction term is the sum of both directions of the KL divergence
    (``batchmean``) between the temperature-scaled softmax of ``out1`` and
    ``out2`` on the batch nodes.

    Returns:
        The sum of both terms. If the feature term is NaN, a message is
        printed and only the prediction term is returned; otherwise, if the
        prediction term is NaN, a message is printed and only the feature
        term is returned.
    """
    feat_term = F.mse_loss(h1[batch_nodes_tensor], h2[batch_nodes_tensor])

    scaled_a = out1[batch_nodes_tensor] / temperature
    scaled_b = out2[batch_nodes_tensor] / temperature
    kl_a_vs_b = F.kl_div(
        F.log_softmax(scaled_a, dim=1),
        F.softmax(scaled_b, dim=1),
        reduction='batchmean',
    )
    kl_b_vs_a = F.kl_div(
        F.log_softmax(scaled_b, dim=1),
        F.softmax(scaled_a, dim=1),
        reduction='batchmean',
    )
    pred_term = kl_a_vs_b + kl_b_vs_a

    # Fall back to whichever term is still finite when one of them is NaN.
    if torch.isnan(feat_term):
        print("feature_consistency is Nan")
        return pred_term
    if torch.isnan(pred_term):
        print("pred_consistency is Nan")
        return feat_term
    return feat_term + pred_term
